"""
This script computes, for each comment, the sum of the weights of the dictionary words occurring in it (word weights are based on WordNet relations).
Weights are read from weighted_dictionary_es.json (generated by 01_wordnet_dict.py); per-comment scores are written to a CSV file.
"""

import json
import multiprocessing as mp
import os
import re
import sqlite3
import sys
import time
from collections import Counter
from functools import partial

import pandas as pd
import spacy

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from utils.config import load_config
from utils.logger import set_logger

# Directory containing this script (despite the name, not the repository root).
# Every relative path below is anchored here so the script resolves the same
# files regardless of the working directory it is launched from.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Anchored to the script directory: the logger is created *before* the
# os.chdir() below, so a bare relative path would resolve against whatever
# directory the script happened to be launched from.
LOG_PATH = os.path.join(PROJECT_ROOT, '../../../logs/02_word_frequency.log')

# Load the configuration file
config = load_config()
logger = set_logger(LOG_PATH)

# Load the Spanish model in spacy after installing it with `python -m spacy download es_core_news_sm`
nlp = spacy.load('es_core_news_sm')

# Set the working directory to the directory of the script. Using the absolute
# path avoids os.chdir('') failing when __file__ is a bare filename.
os.chdir(PROJECT_ROOT)

# CONSTANTS
COMMENTS_TABLE = 'comments_chunk_1'
START_DATETIME = config['start_datetime']  # '2023-01-01T00:00:00Z'
END_DATETIME = config['end_datetime']  # '2023-12-31T23:59:59Z'
COUNTRY_NAME = config['country_name']  # e.g. "mexico"
COUNTRY_CODE = config['country_code']  # e.g. "mx"

# Convert the start and end datetimes to the format `YYYYMM`
START_YEAR_MONTH = re.sub(r'[^0-9]', '', START_DATETIME[:7])
END_YEAR_MONTH = re.sub(r'[^0-9]', '', END_DATETIME[:7])

# PATHS
WORDNET_DIR = os.path.join(PROJECT_ROOT, '../../../data/wordnet')
WEIGHTED_DICT_PATH = os.path.join(PROJECT_ROOT, '../../../output/weighted_dictionary_es.json')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, '../../../output', COUNTRY_NAME)
# The input DB and the output CSV share a folder named after the start month,
# e.g. "2023-01-collection".
COLLECTION_DIR = os.path.join(OUTPUT_DIR, f'{START_YEAR_MONTH[:4]}-{START_YEAR_MONTH[4:]}-collection')
VIDEO_DB_PATH = os.path.join(COLLECTION_DIR, f'{START_YEAR_MONTH}_{END_YEAR_MONTH}_02_{COUNTRY_CODE}_youtube_data.db')
OUTPUT_CSV_FILE = os.path.join(COLLECTION_DIR, f'{START_YEAR_MONTH}_{END_YEAR_MONTH}_05_{COUNTRY_CODE}_comments_scalar_score.csv')

def read_weighted_dict(file_path):
    """Load the lemma -> weight mapping stored as UTF-8 JSON at *file_path*.

    Returns whatever object the JSON document decodes to (expected to be a
    dict of word -> numeric weight).
    """
    with open(file_path, encoding='utf-8') as handle:
        weights = json.load(handle)
    return weights

def lemmatize_and_sum(text, weighted_dict):
    """Lemmatize *text* and weight each dictionary lemma by its occurrence count.

    The text is lower-cased and run through the module-level spaCy pipeline
    ``nlp``; every resulting lemma that is also a key of *weighted_dict* is
    counted.

    Args:
        text: the comment text. ``None`` or an empty string yields ``{}``
            (a NULL comment would otherwise raise inside the worker pool and
            abort the whole run).
        weighted_dict: mapping of lemma -> numeric weight.

    Returns:
        dict mapping each matched lemma to ``occurrences * weight``, with keys
        in order of first appearance in the text. Lemmas absent from
        *weighted_dict* are ignored; an empty dict means nothing matched.
    """
    if not text:
        return {}
    doc = nlp(text.lower())
    # Count every lemma in a single pass (O(tokens)) instead of scanning the
    # whole dictionary and re-counting the token list for each entry.
    lemma_counts = Counter(token.lemma_ for token in doc)
    return {
        lemma: count * weighted_dict[lemma]
        for lemma, count in lemma_counts.items()
        if lemma in weighted_dict
    }

def process_record(record, weighted_dict):
    """Score one comment row against the weighted dictionary.

    Args:
        record: a 5-tuple ``(comment_id, video_id, text_original,
            author_display_name, published_at)`` — the column order selected
            by ``fetch_records_from_db``.
        weighted_dict: mapping of lemma -> weight.

    Returns:
        A dict with the comment metadata, the per-word weighted counts
        (``word_count``) and their total (``scalar_sum``); or ``None`` when no
        dictionary word occurs in the comment.
    """
    comment_id, video_id, text_original, author_display_name, published_at = record
    matches = lemmatize_and_sum(text_original, weighted_dict)

    # No dictionary hits: signal the caller to drop this comment.
    if not matches:
        return None

    scalar_sum = sum(matches.values())
    logger.info(f"Processed comment {comment_id} with scalar sum {scalar_sum}")

    # Key order here fixes the column order of the output CSV.
    return dict(
        comment_id=comment_id,
        video_id=video_id,
        author_display_name=author_display_name,
        published_at=published_at,
        word_count=matches,
        scalar_sum=scalar_sum,
    )

def fetch_records_from_db(db_path, chunk_table_name):
    """Return every comment row of *chunk_table_name* from the SQLite DB at *db_path*.

    Each row is a tuple ``(comment_id, video_id, text_original,
    author_display_name, published_at)``.

    Raises:
        FileNotFoundError: if *db_path* is not an existing file.
            ``sqlite3.connect`` would otherwise silently create an empty
            database there and fail later with a confusing "no such table".
        sqlite3.Error: if the query fails (e.g. the table does not exist).
    """
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    # Identifiers cannot be bound as SQL parameters, so quote the table name
    # and escape any embedded double quotes instead.
    quoted_table = '"' + chunk_table_name.replace('"', '""') + '"'
    query = (
        "SELECT comment_id, video_id, text_original, author_display_name, published_at "
        f"FROM {quoted_table}"
    )

    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(query).fetchall()
    finally:
        # Close the connection even when the query raises.
        conn.close()

if __name__ == '__main__':
    start_time = time.time()
    # Read the weighted dictionary
    weighted_dict = read_weighted_dict(WEIGHTED_DICT_PATH)

    # Fetch records from the database
    records = fetch_records_from_db(VIDEO_DB_PATH, COMMENTS_TABLE)
    logger.info(f"Total records fetched: {len(records)}")
    if not records:
        logger.warning("No records found in the database.")
        sys.exit(1)
    logger.info("Processing records...")

    # Bind weighted_dict so pool.map only has to supply each record
    process_record_with_dict = partial(process_record, weighted_dict=weighted_dict)

    # Lemmatization is CPU-bound: spread the records over one process per core
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.map(process_record_with_dict, records)

    # process_record returns None for comments without any dictionary word
    filtered_results = [result for result in results if result is not None]
    logger.info(f"Comments with at least one dictionary word: {len(filtered_results)}")

    # Name the columns explicitly (they must match the keys returned by
    # process_record) so that an all-None result still yields a CSV with a
    # header row; pandas infers no columns at all from an empty list.
    output_columns = [
        "comment_id",
        "video_id",
        "author_display_name",
        "published_at",
        "word_count",
        "scalar_sum",
    ]
    df = pd.DataFrame(filtered_results, columns=output_columns)

    # Save the DataFrame to a CSV file
    logger.info(f"Writing to {OUTPUT_CSV_FILE}...")
    df.to_csv(OUTPUT_CSV_FILE, index=False)
    logger.info("Done!")
    logger.info(f"Time taken: {time.time() - start_time:.2f} seconds")